
Conversation

jcaip
Contributor

@jcaip jcaip commented Sep 26, 2025

This PR adds support for quantizing nn.Parameter with quantize_.

ModuleFqnToConfig has been renamed to FqnToConfig, which now accepts both module FQNs and parameter FQNs. ModuleFqnToConfig has been kept as an alias to maintain BC.

bc-breaking changes

  1. Passing filter_fn=None to quantize_ has changed semantics.

Previously, when we passed in filter_fn=None, we would just assign it to _is_linear, so passing in None and _is_linear was functionally the same.

Now, None and _is_linear have different semantics. None will ignore filter_fn completely and just use the provided FqnToConfig to quantize the model.

before

# these were equivalent
model = nn.Linear(128, 128)

quantize_(model, config, filter_fn=None)   # same as
quantize_(model, config, filter_fn=_is_linear)

assert isinstance(model.weight, TorchAOBaseTensor)

after

# these are no longer equivalent
model = nn.Linear(128, 128)

quantize_(model, config, filter_fn=None)        # ignores filter_fn, only applies the FqnToConfig entries
quantize_(model, config, filter_fn=_is_linear)  # only quantizes modules matching _is_linear

assert isinstance(model.weight, TorchAOBaseTensor)

This is needed because we have non-linear modules and non-weight parameters that we want to quantize, and otherwise a module would need to both pass filter_fn and be specified in the FqnToConfig.

  2. The default value of filter_fn for quantize_ has changed from None to _is_linear.

To maintain the behavior of quantize_(model, config), we have explicitly changed the default value of filter_fn from None to _is_linear.

  3. Passing in filter_fn=None with _default in FqnToConfig now raises a ValueError.

Before, this would default to _is_linear, but now that filter_fn=None has a different meaning, _default is only supported when a filter_fn is specified.

  4. Passing in a non-default filter_fn with FqnToConfig now raises a ValueError.

To discourage mixing filter_fn and FqnToConfig configuration, quantize_ now throws a ValueError if a custom filter_fn (anything other than _is_linear or None) is passed together with an FqnToConfig.
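
A minimal sketch of this behavior (import locations assumed; the exact error message may differ):

```
import torch.nn as nn

from torchao.quantization import (  # import locations assumed
    FqnToConfig,
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = nn.Sequential(nn.Linear(128, 128))
config = FqnToConfig({
    "0.weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
})

# a custom filter_fn mixed with FqnToConfig is now rejected
try:
    quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, nn.Linear))
except ValueError as e:
    print(e)
```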

API examples

For example, a toy nn.Linear model,

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 128)
        self.linear2 = nn.Linear(128, 128)

model = MyModel()

The keys to FqnToConfig can be one of the following (in order of precedence):

  1. exact module FQN - we can quantize the weight of the first linear as follows:
quant_config = FqnToConfig({
    "linear1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  2. regex that matches a module FQN (prefixed with re:)
quant_config = FqnToConfig({
    "re:linear*": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  3. exact parameter FQN
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  4. regex that matches a parameter FQN (prefixed with re:)
quant_config = FqnToConfig({
    "re:linear*.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
  5. _default, only when filter_fn is specified
quant_config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
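
Putting it together, a minimal end-to-end sketch of targeting a single parameter FQN (import locations assumed; the float8 config generally expects a bf16 model on a recent GPU):

```
import torch
import torch.nn as nn

from torchao.quantization import (  # import locations assumed
    FqnToConfig,
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(128, 128)
        self.linear2 = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear2(self.linear1(x))

model = MyModel().to(torch.bfloat16).cuda()

# only linear1.weight is targeted; linear2 is left in high precision
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
})
quantize_(model, quant_config)

print(type(model.linear1.weight))  # quantized tensor subclass
print(type(model.linear2.weight))  # plain nn.Parameter
```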

To enable parameter FQN support for a particular config, we need to add the parameter_name kwarg to the config signature and update CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS. See the changes here for more details.

Float8DynamicActivationFloat8WeightConfig has been enabled in this PR, but other configs will throw a NotImplementedError.

Test Plan

  1. unit tests for new config:
pytest test/quantization/test_quant_api.py::TestModuleOrParamFqnToConfig
  2. regression test for ModuleFqnToConfig:
pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config
  3. Make sure that we can load old HF checkpoints to maintain BC, run this

How do our configs translate for MoEs?

Currently, we define a number of configs that are aimed at dense nn.Linear modules; how do these configs translate in the case of MoE inference?

Some background on MoE inference

There are two ways that the forward pass is implemented for MoE:

  • For loop of nn.Linear - In this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen here.

In this case, I argue that the semantics of the configs do not change at all from the normal nn.Linear case, as we are just doing a series of ordinary 2d linear matmuls (see the sketch after this list).

  • bmm/grouped mm on the 3d weights / activations directly.

For this case, we'd need to add additional op support (bmm) for the forward pass. Depending on whether the subclass is an AQT subclass or a non-AQT subclass, this will be added differently.

I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).
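
To make the two paths concrete, here is a schematic (not the actual torchao or MoE code) of both forward styles over a 3d expert weight of shape (E, N, K):

```
import torch
import torch.nn.functional as F

E, T, N, K = 4, 16, 256, 128
x = torch.randn(E, T, K)               # per-expert activations
expert_weights = torch.randn(E, N, K)  # 3d expert weight

# 1. "for loop of nn.Linear": E independent 2d matmuls, so each expert weight
#    behaves exactly like an ordinary nn.Linear weight
out_loop = torch.stack([F.linear(x[e], expert_weights[e]) for e in range(E)])

# 2. bmm on the 3d weight directly: this is the path that needs extra op support
#    (torch.bmm) from the quantized tensor subclass
out_bmm = torch.bmm(x, expert_weights.transpose(-2, -1))

assert torch.allclose(out_loop, out_bmm, atol=1e-4)
```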

Below is a breakdown of what Configs map to AQT / non-AQT subclasses:

| not using AQT | using AffineQuantizedTensor (AQT) |
| --- | --- |
| Float8DynamicActivationFloat8WeightConfig | FPXWeightOnlyConfig |
| Float8DynamicActivationInt4WeightConfig | Float8WeightOnlyConfig |
| Float8StaticActivationFloat8WeightConfig | Float8DynamicActivationFloat8SemiSparseWeightConfig |
| Int4WeightOnlyConfig (v2) | GemliteUIntXWeightOnlyConfig |
| | Int4DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt8WeightConfig |
| | Int8WeightOnlyConfig |
| | IntxWeightOnlyConfig |
| | UIntXWeightOnlyConfig |

For these, the majority of the semantics remain the same; the only one that really changes is PerRow granularity, and there's a very natural extension of PerRow to the 3d case (apply it on the last dimension).

I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.

Float8DynamicActivationFloat8WeightConfig

[('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity',
  typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('activation_value_lb', typing.Optional[float]),
 ('activation_value_ub', typing.Optional[float]),
 ('kernel_preference', <enum 'KernelPreference'>),
 ('set_inductor_config', <class 'bool'>),
 ('version', <class 'int'>)]

  • activation_dtype, weight_dtype, activation_value_lb, and activation_value_ub do not change meaning semantically.
  • granularity=PerTensor() does not change semantic meaning - we still use a single scale for the entire weight tensor.
  • granularity=PerRow() does change meaning - we now calculate a scale for each row along the last dimension [-1], i.e. for a weight of shape (E, N, K) we would expect PerRow to create scales with block size (1, 1, K).
  • mm_config, kernel_preference, and set_inductor_config stay the same as well.
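
As a rough illustration of that PerRow extension (a max-abs float8 scale, not the torchao implementation):

```
import torch

E, N, K = 8, 256, 128
weight = torch.randn(E, N, K, dtype=torch.bfloat16)

# one scale per row of the last dimension -> scales of shape (E, N, 1),
# i.e. a quantization block size of (1, 1, K)
amax = weight.abs().amax(dim=-1, keepdim=True).float()
scale = amax / torch.finfo(torch.float8_e4m3fn).max
weight_fp8 = (weight.float() / scale).to(torch.float8_e4m3fn)

print(scale.shape)  # torch.Size([8, 256, 1])
```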

Float8StaticActivationFloat8WeightConfig

[('scale', <class 'torch.Tensor'>),
 ('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity',
  typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('set_inductor_config', <class 'bool'>)]

In the case of PerRow granularity, scale should be passed in as a 3d tensor instead of a 2d tensor.

Float8DynamicActivationInt4WeightConfig

[('int4_packing_format', <enum 'Int4PackingFormat'>)]

int4_packing_format - Only "preshuffled" is supported and Int4PreshuffledTensor supports 3d weights.

Int4WeightOnlyConfig

[('group_size', <class 'int'>),
 ('layout',
  typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
 ('use_hqq', <class 'bool'>),
 ('zero_point_domain',
  typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
 ('set_inductor_config', <class 'bool'>),
 ('preserve_zero', typing.Optional[bool]),
 ('int4_packing_format', <enum 'Int4PackingFormat'>),
 ('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
 ('version', <class 'int'>)]

group_size, int4_packing_format, int4_choose_qparams_algorithm, and set_inductor_config are the only fields that are used for the v2 config.

I don't think the semantics of these change, although there are some packing formats that do not support 3d weights (it looks like Int4PackingFormat.PLAIN_INT32 and Int4PackingFormat.MARLIN_SPARSE).

Summary:

This PR adds a simple 2d and 3d MoE implementation and tests `quantize_` on them to see if we get the same results.

Test Plan:

```
pytest test/prototype/test_parameter.py -k test_quantize_parameter
```


pytorch-bot bot commented Sep 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3083

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 92f9774 with merge base 16c7d09:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Sep 26, 2025
@jcaip jcaip requested review from jerryzh168 and vkuzo September 26, 2025 21:00
@jerryzh168
Contributor

current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?

@vkuzo
Contributor

vkuzo commented Sep 29, 2025

Add in ParamFqnToConfig config
This new config is very similar to ModuleFqnToConfig except it takes in nn.Parameter FQNs and also supports regexs.

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object? Pseudocode of what it could do:

def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ... new behavior for parameters, apply parameter swap config ...
    elif has_parameter(model, fqn + '.weight'):
        ... old behavior, apply parameter swap config ...
    elif has_module(model, fqn):
        ... old behavior, apply module swap ...

@jcaip
Contributor Author

jcaip commented Sep 29, 2025

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object?

Yeah, we can do this. Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

@jcaip
Contributor Author

jcaip commented Sep 29, 2025

current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?

Yes I believe so, especially in the case of the Config object itself. We attach everything to the weight parameter for nn.Linear, so this allows us to specify the parameter name instead of assuming it's "weight".

The only thing that does not map cleanly IMO is the module_registration:

        # non user facing code
        @register_quantize_module_handler(WorkflowFooConfig)
        def _transform(
            mod: torch.nn.Module,
            config: WorkflowFooConfig,
        ) -> torch.nn.Module:
            # the transform is implemented here, usually a tensor sublass
            # weight swap or a module swap

I think we should define the transform for parameters as the base case (aka @register_quantize_handler), and use that for the module flow (assuming the parameter is module.weight), since it's the more general case.
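
Rough sketch of what I mean (the names here are hypothetical, not the existing registration API):

```
import torch
import torch.nn as nn

# hypothetical names for illustration only

def _foo_quantize_param(param: torch.Tensor, config) -> torch.Tensor:
    # base case: produce the quantized tensor for a single parameter;
    # a real handler would return a tensor subclass (e.g. Float8Tensor) here
    return param.to(torch.float8_e4m3fn)  # stand-in for the real conversion

def _foo_module_transform(module: nn.Module, config) -> nn.Module:
    # module flow built on top of the parameter flow, assuming the target is `weight`
    module.weight = nn.Parameter(
        _foo_quantize_param(module.weight, config), requires_grad=False
    )
    return module
```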

@vkuzo
Contributor

vkuzo commented Sep 29, 2025

Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

IMO we should change the current name and keep the old name for BC:

ParamOrModuleFqnToConfig = ...

# for bc
ModuleFqnToConfig = ParamOrModuleFqnToConfig

@vkuzo
Contributor

vkuzo commented Sep 29, 2025

I think we should define the transform for parameters as the base case

To me it seems that the transform has to be for modules, because it is inplace. User can target a parameter if they want to, but the transform function always runs on a module that owns the parameter.

# skip if not direct child
if "." not in name:
    for pattern in config.param_fqn_to_config:
        if re.match(pattern, f"{fqn}.{name}"):
Contributor

so it applies to all params, regardless of what it is? e.g. bias? should we be more specific in what people are configuring?

Contributor Author

I think we should consider the regex syntax separately, I can remove from this PR.

One thing I would like would be for quantize_ to log the modules/params it's swapping so it's easy to see what the difference is.

@andrewor14
Contributor

Does this mean we need to refactor all supported configs to use this structure?

@register_quantized_param_handler(config)
def _float8_dynamic_activation_float8_weight_quantize_tensor(...):
    # returns quantized tensor

def _float8_dynamic_activation_float8_weight_transform(...):
    module.weight = _float8_dynamic_activation_float8_weight_quantize_tensor(...)
    return module

Contributor

@andrewor14 andrewor14 left a comment

Looks good overall, just one main question about how the default filter_fn interacts with the config

(fqn-configuration)=
### 3. FQN Configuration

For granular control, use `ModuleFqnToConfig`:
Contributor

Looks like we also document this in serving.md, can you update that doc as well?

assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor)
assert model.shared_expert.gate_proj.weight.scale.numel() == 1

def test_quantize_modle_exact_match_preference(self):
Contributor

nit: typo modle

"""
torch._C._log_api_usage_once("torchao.quantization.quantize_")

filter_fn = _is_linear if filter_fn is None else filter_fn
Contributor

is this default filter_fn going to have unexpected consequences if people are using FqnToConfig? E.g. let's say someone literally just wants to quantize a very specific parameter:

quantize_(model, FqnToConfig({"layers.0.some.parameter": Int4WeightOnlyConfig()}))

If I'm reading the code correctly, right now we do the replacement if either (1) we match the filter_fn, or (2) we match the fqn. Would the above unexpectedly quantize all the other linear layers in the model?

Contributor Author

In this case, replacement won't do anything as the other linear layers aren't specified in the config. I can add a test for this though.

Contributor

Yeah I think it would be good to verify this, from the code it seems we do the replacement if we match either the filter_fn or the config (not and). Would also be good to clearly document the semantics of filter_fn in the docstring in this case

Contributor

yeah, I think the semantic should be

  • if both fqn_to_config and filter_fn specified, both have to match for config to be applied (AND, not OR)
  • else, use whichever one is applied

it seems like we should consider breaking BC here and change the default filter_fn to is_linear, so that if user passes in filter_fn == None then only fqn_to_config is applied?

Contributor Author

In my mind, if someone specifies a fqn in the config, it's pretty clear that they want to quantize it. So I think AND is kind of a footgun here, especially if the default filter_fn is is_linear. i.e. First time user wants to quantize a parameter, adds an entry to FqnToConfig, and the new param doesn't get quantized because the default filter_fn is is_linear. I guess we can just throw a warning in this instance though.

cc @jerryzh168 what do you think? I'll defer to whatever's most popular with the team.

Contributor Author

Sounds good to me, I'll update the PR

Contributor

agreed on removing filter_fn longer term

I think it is used pretty widely though, so maybe not in this PR and we do it separately with a proper deprecation? We can punt in this PR by just throwing an exception if fqn_to_config is provided along with a non-default filter_fn.

Contributor

filter_fn has a lot of internal uses, and it's how many users apply quantization/QAT to linear and embedding separately today. We should do a careful deprecation of this and make sure existing use cases have a good alternative

Contributor

@andrewor14 , any thoughts on "We can punt in this PR by just throwing an exception if fqn_to_config is provided along with a non-default filter_fn."?

Contributor

We can punt in this PR by just throwing an exception if fqn_to_config is provided along with a non-default filter_fn

Yeah sounds good to me

regex patterns (as strings) to quantization configurations.
The patterns can be one of the follows:
(1). fully qualified name (fqn) of module or paramter or
Contributor

typo: paramter

`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1). fully qualified name (fqn) of module or
module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping
Contributor

the docstring still references the old arg name I think

torch._C._log_api_usage_once("torchao.quantization.FqnToConfig")
if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0:
    warnings.warn(
        "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used"
Contributor

I feel this is going to be a silent error for some users, should we just ban this case for simplicity? It's not for BC

Contributor Author

Yeah, we should just ValueError here.

warnings.warn(
    "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used"
)
if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0:
Contributor

nit: if you throw an error above then this can become:

if len(self.module_fqn_to_config) > 0:
    assert len(self.fqn_to_config) == 0
    self.fqn_to_config = self.module_fqn_to_config

and you don't need the rest of the cases (probably don't need to update self.module_fqn_to_config to match self.fqn_to_config?)

return handler(module, c)

return module
def select_module_if_filter_fn_or_contains_params_matching_pattern(
Contributor

private?

Contributor

@andrewor14 andrewor14 left a comment

Looks good to me! I'll let Jerry/Vasiliy stamp since they reviewed this in more detail

Args:
fqn (str): The fully qualified name to match against the config patterns.
config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs.
torchao/quantization/quant_api.py
Contributor

remove?

"""
torch._C._log_api_usage_once("torchao.quantization.quantize_")

filter_fn = _is_linear if filter_fn is None else filter_fn
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think it would be good to verify this, from the code it seems we do the replacement if we match either the filter_fn or the config (not and). Would also be good to clearly document the semantics of filter_fn in the docstring in this case

return found, c


def _select_module_if_filter_fn_or_contains_params_matching_pattern(
Contributor

IMO should be AND, not OR

_module_fqn_to_config_handler,
filter_fn,
_fqn_to_config_handler,
partial(
Contributor

seems like we are passing one callable and one callable wrapping a callable into a function, seems a bit hard to follow. Have we considered just writing this directly instead?

Contributor Author

I can write this as a lambda, if that's a bit clearer to you?

lambda mod, fqn: filter_fn(mod, fqn) and select_with_module(mod, fqn, config=config)

PerRow,
PerTensor,
)
from .GPTQ import Int4WeightOnlyGPTQQuantizer
Contributor

are the style changes intended? If yes, can we separate into a different PR?

@jerryzh168
Contributor

can you add a BC breaking Notes section since it's breaking BC?

@jcaip jcaip added the topic: bc-breaking label Oct 16, 2025
@jcaip jcaip requested review from jerryzh168 and vkuzo October 16, 2025 14:00